{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp optuna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optuna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">A hyperparameter optimization framework\n",
    "\n",
    "Optuna is an automatic hyperparameter optimization software framework, particularly designed for machine learning. It features an imperative, define-by-run style user API. Thanks to our define-by-run API, the code written with Optuna enjoys high modularity, and the user of Optuna can dynamically construct the search spaces for the hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from pathlib import Path\n",
    "from fastcore.script import *\n",
    "import joblib\n",
    "from importlib import import_module\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|exports\n",
    "def run_optuna_study(objective, resume=None, study_type=None, multivariate=True, search_space=None, evaluate=None, seed=None, sampler=None, pruner=None, \n",
    "                     study_name=None, direction='maximize', n_trials=None, timeout=None, gc_after_trial=False, show_progress_bar=True, \n",
    "                     save_study=True, path='optuna', show_plots=True):\n",
    "    r\"\"\"Creates and runs an optuna study.\n",
    "\n",
    "    Args: \n",
    "        objective:          A callable that implements objective function.\n",
    "        resume:             Path to a previously saved study.\n",
    "        study_type:         Type of study selected (bayesian, gridsearch, randomsearch). Based on this a sampler will be build if sampler is None. \n",
    "                            If a sampler is passed, this has no effect.\n",
    "        multivariate:       If this is True, the multivariate TPE is used when suggesting parameters. The multivariate TPE is reported to outperform \n",
    "                            the independent TPE.\n",
    "        search_space:       Search space required when running a gridsearch (if you don't pass a sampler).\n",
    "        evaluate:           Allows you to pass a specific set of hyperparameters that will be evaluated.\n",
    "        seed:               Fixed seed used by samplers.\n",
    "        sampler:            A sampler object that implements background algorithm for value suggestion. If None is specified, TPESampler is used during \n",
    "                            single-objective optimization and NSGAIISampler during multi-objective optimization. See also samplers.\n",
    "        pruner:             A pruner object that decides early stopping of unpromising trials. If None is specified, MedianPruner is used as the default. \n",
    "                            See also pruners.\n",
    "        study_name:         Study’s name. If this argument is set to None, a unique name is generated automatically.\n",
    "        direction:          A sequence of directions during multi-objective optimization.\n",
    "        n_trials:           The number of trials. If this argument is set to None, there is no limitation on the number of trials. If timeout is also set to \n",
    "                            None, the study continues to create trials until it receives a termination signal such as Ctrl+C or SIGTERM.\n",
    "        timeout:            Stop study after the given number of second(s). If this argument is set to None, the study is executed without time limitation. \n",
    "                            If n_trials is also set to None, the study continues to create trials until it receives a termination signal such as \n",
    "                            Ctrl+C or SIGTERM.\n",
    "        gc_after_trial:     Flag to execute garbage collection at the end of each trial. By default, garbage collection is enabled, just in case. \n",
    "                            You can turn it off with this argument if memory is safely managed in your objective function.\n",
    "        show_progress_bar:  Flag to show progress bars or not. To disable progress bar, set this False.\n",
    "        save_study:         Save your study when finished/ interrupted.\n",
    "        path:               Folder where the study will be saved.\n",
    "        show_plots:         Flag to control whether plots are shown at the end of the study.\n",
    "    \"\"\"\n",
    "    \n",
    "    try: import optuna\n",
    "    except ImportError: raise ImportError('You need to install optuna to use run_optuna_study')\n",
    "\n",
    "    # Sampler\n",
    "    if sampler is None:\n",
    "        if study_type is None or \"bayes\" in study_type.lower(): \n",
    "            sampler = optuna.samplers.TPESampler(seed=seed, multivariate=multivariate)\n",
    "        elif \"grid\" in study_type.lower():\n",
    "            assert search_space, f\"you need to pass a search_space dict to run a gridsearch\"\n",
    "            sampler = optuna.samplers.GridSampler(search_space)\n",
    "        elif \"random\" in study_type.lower(): \n",
    "            sampler = optuna.samplers.RandomSampler(seed=seed)\n",
    "    assert sampler, \"you need to either select a study type (bayesian, gridsampler, randomsampler) or pass a sampler\"\n",
    "\n",
    "    # Study\n",
    "    if resume: \n",
    "        try:\n",
    "            study = joblib.load(resume)\n",
    "        except: \n",
    "            print(f\"joblib.load({resume}) couldn't recover any saved study. Check the path.\")\n",
    "            return\n",
    "        print(\"Best trial until now:\")\n",
    "        print(\" Value: \", study.best_trial.value)\n",
    "        print(\" Params: \")\n",
    "        for key, value in study.best_trial.params.items():\n",
    "            print(f\"    {key}: {value}\")\n",
    "    else: \n",
    "        study = optuna.create_study(sampler=sampler, pruner=pruner, study_name=study_name, direction=direction)\n",
    "    if evaluate: study.enqueue_trial(evaluate)\n",
    "    try:\n",
    "        study.optimize(objective, n_trials=n_trials, timeout=timeout, gc_after_trial=gc_after_trial, show_progress_bar=show_progress_bar)\n",
    "    except KeyboardInterrupt:\n",
    "        pass\n",
    "\n",
    "    # Save\n",
    "    if save_study:\n",
    "        full_path = Path(path)/f'{study.study_name}.pkl'\n",
    "        full_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "        joblib.dump(study, full_path)\n",
    "        print(f'\\nOptuna study saved to {full_path}')\n",
    "        print(f\"To reload the study run: study = joblib.load('{full_path}')\")\n",
    "\n",
    "    # Plots\n",
    "    if show_plots and len(study.trials) > 1:\n",
    "        try: display(optuna.visualization.plot_optimization_history(study))\n",
    "        except: pass\n",
    "        try: display(optuna.visualization.plot_param_importances(study))\n",
    "        except: pass\n",
    "        try: display(optuna.visualization.plot_slice(study))\n",
    "        except: pass\n",
    "        try: display(optuna.visualization.plot_parallel_coordinate(study))\n",
    "        except: pass\n",
    "\n",
    "    # Study stats\n",
    "    try:\n",
    "        pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]\n",
    "        complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]\n",
    "        print(f\"\\nStudy statistics    : \")\n",
    "        print(f\"  Study name        : {study.study_name}\")\n",
    "        print(f\"  # finished trials : {len(study.trials)}\")\n",
    "        print(f\"  # pruned trials   : {len(pruned_trials)}\")\n",
    "        print(f\"  # complete trials : {len(complete_trials)}\")\n",
    "        \n",
    "        print(f\"\\nBest trial          :\")\n",
    "        trial = study.best_trial\n",
    "        print(f\"  value             : {trial.value}\")\n",
    "        print(f\"  best_params = {trial.params}\\n\")\n",
    "    except:\n",
    "        print('\\nNo finished trials yet.')\n",
    "    return study"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/200_optuna.ipynb saved at 2022-11-09 13:16:31\n",
      "Correct notebook to script conversion! 😃\n",
      "Wednesday 09/11/22 13:16:34 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcOwn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
